
using LinearAlgebra


#=
B1 B2 .. Bl .. Bn-1 Bn
          ^ e.g. slice l is currently somewhere here
G(τ-Δτ, τ-Δτ) = [I + B(τ-Δτ, 0) B(β, τ-Δτ)]^{-1}
G(0,0) = [I + B(β,0)]^{-1} = [I + Bn Bn-1 ... B2 B1]^{-1}
G(l,l) = [I + B(l,0)B(β,l+1)]^{-1} = [I + Bl .. B1 Bn ... Bl+1]^{-1}

l may only take values on the stabilization-interval grid
F[l] = Bl .. B1
F[l+1] = Bn ... Bl+1

Multiplying UDV decompositions:
UDV UDV = U (D V U D) V = (U U') D' (V' V),   where D V U D = U' D' V'
=#

"""
    ScrollSVD{T}

Chain of SVD factorizations of partial products of block matrices.

Fields (the constructor and `push!` grow all three vectors together):
- `B`: the block matrices; each one is `bmats[end] * ... * bmats[1]` of the slice
  matrices it was built from.
- `F`: SVD factorizations of the accumulated products. An R-side entry `k` holds
  `B[k]...B[1]`, and an L-side entry `k` holds `B[N]...B[k]`.
- `L`: `true` when entry `k` belongs to the L side, `false` for the R side.
"""
struct ScrollSVD{T}
    B::Vector{Matrix{T}}
    # NOTE(review): `SVD{T}` leaves SVD's remaining type parameters free, so this
    # vector has an abstract element type.
    F::Vector{SVD{T}}
    L::Vector{Bool}
end


"""
    ScrollSVD(bmats::Vector{Matrix{T}})

Start a new scroll whose single entry is on the R side. The entry is built from
`bmats` multiplied in reverse order, `bmats[end] * ... * bmats[1]`, together with
its QR-iteration SVD.
"""
function ScrollSVD(bmats::Vector{Matrix{T}}) where T
    product = prod(reverse(bmats))
    factorization = svd(product; alg = LinearAlgebra.QRIteration())
    return ScrollSVD{T}([product], [factorization], [false])
end


"""
    ScrollIter

Bookkeeping for iterating over tau slices that are grouped into sectors.

Fields:
- `Sector`: number of tau slices contained in each sector.
- `Invsec`: for each tau (1-based position), the index of the sector it belongs to.
- `Inc`: whether iteration should step tau upward (`true`) or downward (`false`).
- `Tot`: total number of tau slices over all sectors.
"""
mutable struct ScrollIter
    Sector::Vector{Int}
    Invsec::Vector{Int}
    Inc::Bool
    Tot::Int
end

# Convenience alias for a ScrollSVD whose matrices have ComplexF64 elements.
const ScrollSVDZ = ScrollSVD{ComplexF64}


"""
    push!(ss::ScrollSVD{T}, bmats::Vector{Matrix{T}}) -> ss

Append a new block `B = bmats[end] * ... * bmats[1]` to `ss`.

The block extends the R-side chain, `B(t, 0)`. Entries can only be pushed onto R;
note that L is on the right and R on the left. If `ss.F[end] = U D Vt`, the new
factorization of `B * U * D * Vt` is obtained as
`svd(B * U * D) = U' D' Vt'` followed by `Vt' * Vt`. The full product is
therefore never formed explicitly.

If `ss` is empty, the new entry is simply the SVD of `B`.

Throws `ArgumentError` if any entry of `ss` is already on the L side. In that
case `ss.F[end]` is an L-side product, and composing with it would corrupt the
chain.
"""
function Base.push!(ss::ScrollSVD{T}, bmats::Vector{Matrix{T}}) where T
    any(ss.L) && throw(ArgumentError(
        "push! only appends to R, but ss already contains L entries"))
    B = prod(reverse(bmats))
    if isempty(ss.F)
        # Nothing accumulated yet: the new chain entry is B itself.
        Fn = svd(B, alg=LinearAlgebra.QRIteration())
    else
        F = ss.F[end]
        # B (U D Vt) = (B U D) Vt = U' D' Vt' Vt
        M = B * F.U * Diagonal(F.S)
        Fp = svd(M, alg=LinearAlgebra.QRIteration())
        Fn = SVD(Fp.U, Fp.S, Fp.Vt * F.Vt)
    end
    # Mutate only after the SVD succeeded, so B/F/L never end up with different
    # lengths.
    push!(ss.B, B)
    push!(ss.F, Fn)
    push!(ss.L, false)
    return ss
end


"""
    push!(si::ScrollIter, seclen)

Append a sector of `seclen` tau slices to `si`. Every new tau is mapped to the
new sector's index in `si.Invsec`, and `si.Tot` is increased. The updated total
is returned.
"""
function Base.push!(si::ScrollIter, seclen)
    push!(si.Sector, seclen)
    owner = length(si.Sector)
    append!(si.Invsec, (owner for _ in Base.OneTo(seclen)))
    si.Tot = si.Tot + seclen
end


"""
    iterate(si::ScrollIter[, tau::Int])

Iterate over tau slices `0, ..., si.Tot-1`, in ascending order when `si.Inc` is
true and in descending order otherwise.

Each element is a tuple `(tau, isec)`:
- when ascending, `isec = si.Invsec[tau+1] - 1`;
- when descending, `isec = si.Invsec[tau+1]`.

The iteration state is `tau`. An iterator with `si.Tot == 0` yields nothing.
"""
function Base.iterate(si::ScrollIter)
    # Start one step outside [0, Tot-1] so the stepping method produces the first
    # element. This also returns `nothing` right away for an empty iterator.
    return Base.iterate(si, si.Inc ? -1 : si.Tot)
end


function Base.iterate(si::ScrollIter, tau::Int)
    tau += si.Inc ? 1 : -1
    # Stop once tau leaves [0, Tot-1] in the direction of travel.
    if (si.Inc ? tau >= si.Tot : tau < 0)
        return nothing
    end
    isec = si.Invsec[tau + 1]   # Invsec is 1-based, tau is 0-based
    return (tau, si.Inc ? isec - 1 : isec), tau
end



"""
    reset_scroll(ss::ScrollSVD{T}, Lnum)

Unimplemented stub. The body is empty, so calling it does nothing and returns
`nothing`. The planned steps are listed in the comments inside.
"""
function reset_scroll(ss::ScrollSVD{T}, Lnum) where T
    # Reset L
    # Recompute all B
    # Loop from Lnum to L
    # Loop from the total count down to L+1
    # TODO(review): not implemented; the exact meaning of `Lnum` is not visible here — confirm.
end



"""
将一个B从R放到L
R:
L[1] ... L[i] = false 
F[1] = B[1] ; ; F[i] = B[i]...B[1]
L:
L[i+1] ... L[N] = true
F[i+1] = B[N]...B[i+1] ; ; F[N] = B[N]
===
L[i] = false -> true
B'[i] = prod(reverse(bmats))
F[i] = B[N]...B[i+1] B'[i] = F[i+1] B'[i]
要保证除了B[i]以外的内容没有变化
"""
function scrollR2L(ss::ScrollSVD{T}, bmats::Vector{Matrix{T}}) where T
    ptr = findfirst(ss.L)
    if ptr == 1
        throw(error("does not contain R"))
    end
    #ptr = isnothing(ptr) ? length(ss.L) : ptr-1
    #Bp = prod(reverse(bmats))
    #ss.B[ptr] = Bp
    #
    Bp = prod(reverse(bmats))
    if isnothing(ptr)
        #都在R里面，放一个进去就行
        ss.B[end] = Bp
        ss.F[end] = svd(Bp, alg=LinearAlgebra.QRIteration())
        ss.L[end] = true
    else
        ss.B[ptr-1] = Bp
        #这个F在L中
        F = ss.F[ptr]
        VL, DL, UL = F.U, Diagonal(F.S), F.Vt
        M = DL*UL*Bp
        Fp = svd(M, alg=LinearAlgebra.QRIteration())
        ss.F[ptr-1] = SVD(VL*Fp.U, Fp.S, Fp.Vt)
        ss.L[ptr-1] = true
    end
end



"""
将一个B从R放到L
R:
L[1] ... L[i] = false 
F[1] = B[1] ; ; F[i] = B[i]...B[1]
L:
L[i+1] ... L[N] = true
F[i+1] = B[N]...B[i+1] ; ; F[N] = B[N]
===
L[i+1] = true -> false
B'[i+1] = prod(reverse(bmats))
F[i+1] = B'[i+1]B[i]...B[1] = B'[i+1] F[i]
要保证除了B[i]以外的内容没有变化
"""
function scrollL2R(ss::ScrollSVD{T}, bmats::Vector{Matrix{T}}) where T
    ptr = findfirst(ss.L)
    if isnothing(ptr)
        throw(error("does not contain L"))
    end
    #
    Bp = prod(reverse(bmats))
    if ptr == 1
        #都在L里面，R放进去一个就行
        ss.B[ptr] = Bp
        ss.F[ptr] = svd(Bp, alg=LinearAlgebra.QRIteration())
        ss.L[ptr] = false
    else
        #
        ss.B[ptr] = Bp
        F = ss.F[ptr-1]
        UR, DR, VR = F.U, Diagonal(F.S), F.Vt
        M = Bp*UR*DR
        Fp = svd(M, alg=LinearAlgebra.QRIteration())
        ss.F[ptr] = SVD(Fp.U, Fp.S, Fp.Vt*VR)
        ss.L[ptr] = false
    end
end
